450.4 Final Project

Methods:

In [1]:
# Load the required packages.

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn import metrics
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

import tensorflow as tf
from keras import backend as K
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, LeakyReLU, BatchNormalization
from keras.metrics import AUC
In [2]:
# Load the dataset.

COVID_data = pd.read_csv('COVID_data.csv')
In [3]:
COVID_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10020 entries, 0 to 10019
Data columns (total 16 columns):
 #   Column                                  Non-Null Count  Dtype 
---  ------                                  --------------  ----- 
 0   Age.at.diagnosis                        10020 non-null  object
 1   Sex                                     10020 non-null  object
 2   Month.first.diagnosis                   10020 non-null  object
 3   Year.first.diagnosis                    10020 non-null  int64 
 4   Uncomplicated.phase                     10020 non-null  object
 5   Complicated.phase                       10020 non-null  object
 6   Critical.phase                          10020 non-null  object
 7   Recovery.phase                          10020 non-null  object
 8   Vasopressors.in.complicated.phase       4725 non-null   object
 9   Vasopressors.in.critical.phase          1856 non-null   object
 10  Invasive.ventilation.in.critical.phase  1856 non-null   object
 11  Superinfection.in.uncomplicated.phase   8064 non-null   object
 12  Superinfection.in.complicated.phase     4725 non-null   object
 13  Superinfection.in.critical.phase        1856 non-null   object
 14  Symptoms.in.recovery.phase              4941 non-null   object
 15  Last.known.patient.status               10020 non-null  object
dtypes: int64(1), object(15)
memory usage: 1.2+ MB
In [4]:
# Fill any blanks in the data with n/a.

COVID_data = COVID_data.fillna("n/a")
In [5]:
COVID_data.head(20)
Out[5]:
Age.at.diagnosis Sex Month.first.diagnosis Year.first.diagnosis Uncomplicated.phase Complicated.phase Critical.phase Recovery.phase Vasopressors.in.complicated.phase Vasopressors.in.critical.phase Invasive.ventilation.in.critical.phase Superinfection.in.uncomplicated.phase Superinfection.in.complicated.phase Superinfection.in.critical.phase Symptoms.in.recovery.phase Last.known.patient.status
0 26 - 45 years Female <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
1 46 - 65 years Female <= 3 2020 yes yes no yes no n/a n/a none none n/a yes Recovered
2 <= 25 years Female <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
3 46 - 65 years Male <= 3 2020 yes no no yes n/a n/a n/a none n/a n/a no Recovered
4 26 - 45 years Female <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
5 46 - 65 years Female <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
6 46 - 65 years Male <= 3 2020 yes no no no n/a n/a n/a unknown/missing n/a n/a n/a Not recovered
7 46 - 65 years Male <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
8 46 - 65 years Male <= 3 2020 yes no no no n/a n/a n/a unknown/missing n/a n/a n/a Not recovered
9 46 - 65 years Female <= 3 2020 yes no no no n/a n/a n/a unknown/missing n/a n/a n/a Not recovered
10 46 - 65 years Female <= 3 2020 yes yes no no no n/a n/a none none n/a n/a Recovered
11 46 - 65 years Male <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
12 26 - 45 years Male <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
13 26 - 45 years Female <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
14 26 - 45 years Male <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a Recovered
15 26 - 45 years Male <= 3 2020 yes no no no n/a n/a n/a none n/a n/a n/a unknown/missing
16 26 - 45 years Male <= 3 2020 yes no no no n/a n/a n/a unknown/missing n/a n/a n/a Recovered
17 46 - 65 years Male <= 3 2020 yes no yes no n/a yes yes none n/a bacterial n/a Recovered
18 46 - 65 years Male <= 3 2020 yes yes yes no yes yes yes none none none n/a Recovered
19 46 - 65 years Male <= 3 2020 yes no yes no n/a yes yes none n/a none n/a Recovered

Visualizations:

In [6]:
# Use count plots to visualize the distribution of the data.

sns.set(rc={'figure.figsize':(11.7,8.27)})

sns.countplot(x="Age.at.diagnosis", data= COVID_data, 
              order=['<= 25 years','26 - 45 years','46 - 65 years','66 - 85 years','> 85 years'])
plt.show()

# Most of the individuals in this survey are 46-85 years old.
## Very few are under 25, and very few are over 85.
In [7]:
sns.countplot(x="Sex", data= COVID_data)
plt.show()

# There are more males represented in this dataset than females.
In [8]:
sns.countplot(x="Month.first.diagnosis", data= COVID_data, 
              order=['1','2','<= 3','3','4','5','6','7','8','9','10','11','12'])
plt.show()

# Most patients were first diagnosed in the first four months of the year (months 1-4).
## A substantial portion were first diagnosed in months 10-12.
In [9]:
sns.countplot(x="Year.first.diagnosis", data= COVID_data)
plt.show()

# Most individuals in this survey were first diagnosed in 2020.
In [10]:
sns.countplot(x="Uncomplicated.phase", data= COVID_data, order=['no','yes'])
plt.show()

# The majority of individuals infected with COVID-19 experience the uncomplicated phase.
In [11]:
sns.countplot(x="Complicated.phase", data= COVID_data)
plt.show()

# A substantial portion of individuals experience the complicated phase.
In [12]:
sns.countplot(x="Critical.phase", data= COVID_data)
plt.show()

# The majority of individuals do not enter the critical phase.
In [13]:
sns.countplot(x="Recovery.phase", data= COVID_data)
plt.show()

# A substantial portion of individuals enter the recovery phase.
## NOTE: some patients may recover without being listed as having entered the recovery phase.
## Read more about how the four different phases are diagnosed here: https://leoss.net/statistics/.
In [14]:
sns.countplot(x="Vasopressors.in.complicated.phase", data= COVID_data[~(COVID_data == 'n/a')],
              order=['no','yes', 'unknown/missing'])
plt.show()
In [15]:
sns.countplot(x="Vasopressors.in.critical.phase", data= COVID_data[~(COVID_data == 'n/a')],
              order=['no','yes', 'unknown/missing'])
plt.show()

# Far more patients required vasopressors in the critical phase than in the complicated phase.
## Vasopressors are used to constrict blood vessels for people with low blood pressure.
In [16]:
sns.countplot(x="Invasive.ventilation.in.critical.phase", data= COVID_data[~(COVID_data == 'n/a')],
              order=['no','yes', 'unknown/missing'])
plt.show()

# A substantial portion of individuals required invasive ventilation in the critical phase.
## Invasive ventilation involves inserting a tube in the throat to take over respiratory function.
In [17]:
sns.countplot(x="Superinfection.in.uncomplicated.phase", data= COVID_data[~(COVID_data == 'n/a')],
             order=['none','bacterial', 'bacterial&fungal', 'fungal', 'unknown/missing'])
plt.show()
In [18]:
sns.countplot(x="Superinfection.in.complicated.phase", data= COVID_data[~(COVID_data == 'n/a')],
             order=['none','bacterial', 'bacterial&fungal', 'fungal', 'unknown/missing'])
plt.show()
In [19]:
sns.countplot(x="Superinfection.in.critical.phase", data= COVID_data[~(COVID_data == 'n/a')],
             order=['none','bacterial', 'bacterial&fungal', 'fungal', 'unknown/missing'])
plt.show()

# A much larger portion of individuals in the critical phase experience superinfections.
## Bacterial superinfections tend to be the most common.
In [20]:
sns.countplot(x="Symptoms.in.recovery.phase", data= COVID_data[~(COVID_data == 'n/a')],
              order=['no','yes', 'unknown/missing'])
plt.show()

# Most patients did not experience symptoms in the recovery phase.
In [21]:
sns.countplot(x="Last.known.patient.status", data= COVID_data[~(COVID_data == 'unknown/missing')])
plt.show()

# The majority of the patients in this dataset recovered from COVID-19.
## However, I am interested in predicting the chance that someone will die from COVID-19 based on this data.
## I will therefore be looking at those categorized as "Dead from COVID-19".

Data Cleaning:

In [22]:
# Make all variables numeric, on a scale of 0 to 1.
## Intermediate values should be evenly spaced from 0 to 1.
## I will be cleaning and scaling the data manually due to some inconsistencies in it.

COVID_data.loc[(COVID_data['Age.at.diagnosis'] == '<= 25 years'), 'Age.at.diagnosis'] = 0
COVID_data.loc[(COVID_data['Age.at.diagnosis'] == '26 - 45 years'), 'Age.at.diagnosis'] = .25
COVID_data.loc[(COVID_data['Age.at.diagnosis'] == '46 - 65 years'), 'Age.at.diagnosis'] = .5
COVID_data.loc[(COVID_data['Age.at.diagnosis'] == '66 - 85 years'), 'Age.at.diagnosis'] = .75
COVID_data.loc[(COVID_data['Age.at.diagnosis'] == '> 85 years'), 'Age.at.diagnosis'] = 1
In [23]:
COVID_data.loc[(COVID_data['Sex'] == 'Female'), 'Sex'] = 0
COVID_data.loc[(COVID_data['Sex'] == 'Male'), 'Sex'] = 1
In [24]:
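# I will be combining "<= 3" with "2" for the purposes of this model.
## The astype(float) call below is only a preview: its result is not assigned,
## so the column still holds the raw month strings for the next cell.

COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '<= 3'), 'Month.first.diagnosis'] = 0.0909
COVID_data['Month.first.diagnosis'].astype(float)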
Out[24]:
0         0.0909
1         0.0909
2         0.0909
3         0.0909
4         0.0909
          ...   
10015     8.0000
10016     8.0000
10017     8.0000
10018     7.0000
10019    12.0000
Name: Month.first.diagnosis, Length: 10020, dtype: float64
In [25]:
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '1'), 'Month.first.diagnosis'] = 0
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '2'), 'Month.first.diagnosis'] = .0909
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '3'), 'Month.first.diagnosis'] = .1818
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '4'), 'Month.first.diagnosis'] = .2727
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '5'), 'Month.first.diagnosis'] = .3636
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '6'), 'Month.first.diagnosis'] = .4545
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '7'), 'Month.first.diagnosis'] = .5454
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '8'), 'Month.first.diagnosis'] = .6363
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '9'), 'Month.first.diagnosis'] = .7272
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '10'), 'Month.first.diagnosis'] = .8181
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '11'), 'Month.first.diagnosis'] = 0.909
COVID_data.loc[(COVID_data['Month.first.diagnosis'] == '12'), 'Month.first.diagnosis'] = 1
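
The constants typed out above are the multiples of 1/11 truncated to four decimals, so that months 1 to 12 are evenly spaced on [0, 1]. As a minimal sketch of the same idea, the hypothetical helper `scale_month` below (not part of the original notebook) computes the value from the label, assuming, as in the cell above, that "<= 3" is treated like month "2".

```python
def scale_month(label):
    """Map a raw month label to an evenly spaced value in [0, 1]."""
    # Assumption carried over from the notebook: '<= 3' is merged with month 2.
    month = 2 if label.strip() == '<= 3' else int(label)
    # Months 1..12 become 0, 1/11, 2/11, ..., 1.
    return (month - 1) / 11


for raw in ['1', '<= 3', '3', '12']:
    print(raw, round(scale_month(raw), 4))
```

If applied to the raw string column, something like `COVID_data['Month.first.diagnosis'].map(scale_month)` should give the same encoding up to rounding, since `Series.map` accepts a function.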
In [26]:
COVID_data.loc[(COVID_data['Year.first.diagnosis'] == 2020), 'Year.first.diagnosis'] = 0
COVID_data.loc[(COVID_data['Year.first.diagnosis'] == 2021), 'Year.first.diagnosis'] = 1
In [27]:
# "N/A"s and "No"s are 0s, "Yes"s are 1s.

COVID_data.loc[(COVID_data['Uncomplicated.phase'] == 'n/a'), 'Uncomplicated.phase'] = 0
COVID_data.loc[(COVID_data['Uncomplicated.phase'] == 'no'), 'Uncomplicated.phase'] = 0
COVID_data.loc[(COVID_data['Uncomplicated.phase'] == 'yes'), 'Uncomplicated.phase'] = 1
In [28]:
COVID_data.loc[(COVID_data['Complicated.phase'] == 'n/a'), 'Complicated.phase'] = 0
COVID_data.loc[(COVID_data['Complicated.phase'] == 'no'), 'Complicated.phase'] = 0
COVID_data.loc[(COVID_data['Complicated.phase'] == 'yes'), 'Complicated.phase'] = 1
In [29]:
COVID_data.loc[(COVID_data['Critical.phase'] == 'n/a'), 'Critical.phase'] = 0
COVID_data.loc[(COVID_data['Critical.phase'] == 'no'), 'Critical.phase'] = 0
COVID_data.loc[(COVID_data['Critical.phase'] == 'yes'), 'Critical.phase'] = 1
In [30]:
COVID_data.loc[(COVID_data['Recovery.phase'] == 'n/a'), 'Recovery.phase'] = 0
COVID_data.loc[(COVID_data['Recovery.phase'] == 'no'), 'Recovery.phase'] = 0
COVID_data.loc[(COVID_data['Recovery.phase'] == 'yes'), 'Recovery.phase'] = 1
In [31]:
COVID_data.loc[(COVID_data['Vasopressors.in.complicated.phase'] == 'n/a'), 'Vasopressors.in.complicated.phase'] = 0
COVID_data.loc[(COVID_data['Vasopressors.in.complicated.phase'] == 'unknown/missing'), 'Vasopressors.in.complicated.phase'] = 0
COVID_data.loc[(COVID_data['Vasopressors.in.complicated.phase'] == 'no'), 'Vasopressors.in.complicated.phase'] = 0
COVID_data.loc[(COVID_data['Vasopressors.in.complicated.phase'] == 'yes'), 'Vasopressors.in.complicated.phase'] = 1
In [32]:
COVID_data.loc[(COVID_data['Vasopressors.in.critical.phase'] == 'n/a'), 'Vasopressors.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Vasopressors.in.critical.phase'] == 'unknown/missing'), 'Vasopressors.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Vasopressors.in.critical.phase'] == 'no'), 'Vasopressors.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Vasopressors.in.critical.phase'] == 'yes'), 'Vasopressors.in.critical.phase'] = 1
In [33]:
COVID_data.loc[(COVID_data['Invasive.ventilation.in.critical.phase'] == 'n/a'), 'Invasive.ventilation.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Invasive.ventilation.in.critical.phase'] == 'unknown/missing'), 'Invasive.ventilation.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Invasive.ventilation.in.critical.phase'] == 'no'), 'Invasive.ventilation.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Invasive.ventilation.in.critical.phase'] == 'yes'), 'Invasive.ventilation.in.critical.phase'] = 1
In [34]:
# All forms of superinfections are 1s.

COVID_data.loc[(COVID_data['Superinfection.in.uncomplicated.phase'] == 'n/a'), 'Superinfection.in.uncomplicated.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.uncomplicated.phase'] == 'unknown/missing'), 'Superinfection.in.uncomplicated.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.uncomplicated.phase'] == 'none'), 'Superinfection.in.uncomplicated.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.uncomplicated.phase'] == 'bacterial'), 'Superinfection.in.uncomplicated.phase'] = 1
COVID_data.loc[(COVID_data['Superinfection.in.uncomplicated.phase'] == 'bacterial&fungal'), 'Superinfection.in.uncomplicated.phase'] = 1
COVID_data.loc[(COVID_data['Superinfection.in.uncomplicated.phase'] == 'fungal'), 'Superinfection.in.uncomplicated.phase'] = 1
In [35]:
COVID_data.loc[(COVID_data['Superinfection.in.complicated.phase'] == 'n/a'), 'Superinfection.in.complicated.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.complicated.phase'] == 'unknown/missing'), 'Superinfection.in.complicated.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.complicated.phase'] == 'none'), 'Superinfection.in.complicated.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.complicated.phase'] == 'bacterial'), 'Superinfection.in.complicated.phase'] = 1
COVID_data.loc[(COVID_data['Superinfection.in.complicated.phase'] == 'bacterial&fungal'), 'Superinfection.in.complicated.phase'] = 1
COVID_data.loc[(COVID_data['Superinfection.in.complicated.phase'] == 'fungal'), 'Superinfection.in.complicated.phase'] = 1
In [36]:
COVID_data.loc[(COVID_data['Superinfection.in.critical.phase'] == 'n/a'), 'Superinfection.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.critical.phase'] == 'unknown/missing'), 'Superinfection.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.critical.phase'] == 'none'), 'Superinfection.in.critical.phase'] = 0
COVID_data.loc[(COVID_data['Superinfection.in.critical.phase'] == 'bacterial'), 'Superinfection.in.critical.phase'] = 1
COVID_data.loc[(COVID_data['Superinfection.in.critical.phase'] == 'bacterial&fungal'), 'Superinfection.in.critical.phase'] = 1
COVID_data.loc[(COVID_data['Superinfection.in.critical.phase'] == 'fungal'), 'Superinfection.in.critical.phase'] = 1
In [37]:
COVID_data.loc[(COVID_data['Symptoms.in.recovery.phase'] == 'n/a'), 'Symptoms.in.recovery.phase'] = 0
COVID_data.loc[(COVID_data['Symptoms.in.recovery.phase'] == 'unknown/missing'), 'Symptoms.in.recovery.phase'] = 0
COVID_data.loc[(COVID_data['Symptoms.in.recovery.phase'] == 'no'), 'Symptoms.in.recovery.phase'] = 0
COVID_data.loc[(COVID_data['Symptoms.in.recovery.phase'] == 'yes'), 'Symptoms.in.recovery.phase'] = 1
In [38]:
# Only "Dead from COVID-19" is a 1, all other instances are 0s.

COVID_data.loc[(COVID_data['Last.known.patient.status'] == 'Dead from COVID-19'), 'Last.known.patient.status'] = 1
COVID_data.loc[(COVID_data['Last.known.patient.status'] == 'Dead from other causes'), 'Last.known.patient.status'] = 0
COVID_data.loc[(COVID_data['Last.known.patient.status'] == 'Not recovered'), 'Last.known.patient.status'] = 0
COVID_data.loc[(COVID_data['Last.known.patient.status'] == 'Recovered'), 'Last.known.patient.status'] = 0
COVID_data.loc[(COVID_data['Last.known.patient.status'] == 'unknown/missing'), 'Last.known.patient.status'] = 0
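
The repeated `.loc` assignments above all follow one pattern: look up a raw label and replace it with 0 or 1. The sketch below expresses that pattern with lookup tables. The names `YES_NO`, `SUPERINFECTION`, `encode`, and `encode_death` are hypothetical and are not used elsewhere in this notebook; the category labels are the ones that appear in the cells above.

```python
# Hypothetical lookup tables built from the labels used in the cleaning cells.
YES_NO = {'yes': 1, 'no': 0, 'unknown/missing': 0, 'n/a': 0}
SUPERINFECTION = {'bacterial': 1, 'fungal': 1, 'bacterial&fungal': 1,
                  'none': 0, 'unknown/missing': 0, 'n/a': 0}


def encode(value, table):
    """Return the 0/1 code for a raw label, failing loudly on unseen labels."""
    if value not in table:
        raise ValueError(f"unexpected label: {value!r}")
    return table[value]


def encode_death(status):
    """Return 1 only for 'Dead from COVID-19'; every other status is 0."""
    return int(status == 'Dead from COVID-19')


print(encode('yes', YES_NO), encode('fungal', SUPERINFECTION),
      encode_death('Recovered'))
```

With pandas, the tables could be passed directly to `Series.map`, for example `COVID_data[col].map(YES_NO)`. Labels missing from the table would then become NaN rather than raising, so a check for NaN afterwards would be a sensible safeguard.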
In [39]:
COVID_data.head(20)
Out[39]:
Age.at.diagnosis Sex Month.first.diagnosis Year.first.diagnosis Uncomplicated.phase Complicated.phase Critical.phase Recovery.phase Vasopressors.in.complicated.phase Vasopressors.in.critical.phase Invasive.ventilation.in.critical.phase Superinfection.in.uncomplicated.phase Superinfection.in.complicated.phase Superinfection.in.critical.phase Symptoms.in.recovery.phase Last.known.patient.status
0 0.25 0 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
1 0.5 0 0.0909 0 1 1 0 1 0 0 0 0 0 0 1 0
2 0 0 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
3 0.5 1 0.0909 0 1 0 0 1 0 0 0 0 0 0 0 0
4 0.25 0 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
5 0.5 0 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
6 0.5 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
7 0.5 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
8 0.5 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
9 0.5 0 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
10 0.5 0 0.0909 0 1 1 0 0 0 0 0 0 0 0 0 0
11 0.5 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
12 0.25 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
13 0.25 0 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
14 0.25 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
15 0.25 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
16 0.25 1 0.0909 0 1 0 0 0 0 0 0 0 0 0 0 0
17 0.5 1 0.0909 0 1 0 1 0 0 1 1 0 0 1 0 0
18 0.5 1 0.0909 0 1 1 1 0 1 1 1 0 0 0 0 0
19 0.5 1 0.0909 0 1 0 1 0 0 1 1 0 0 0 0 0
In [40]:
# Create an alternate dataset of just patients who have entered the complicated phase.
## Use this to determine whether the model is generalizable or not.

complicated_data = COVID_data[COVID_data['Complicated.phase'] == 1]

Results:

Predict whether a patient will die from COVID-19.

In [41]:
# Use the full COVID-19 dataset.

COVID_data = COVID_data.drop(columns=['Recovery.phase','Symptoms.in.recovery.phase'])
In [42]:
# Create x and y variables, and train and test sets.

x = COVID_data
y = x.pop('Last.known.patient.status').to_frame()

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)
x_train.head()
Out[42]:
Age.at.diagnosis Sex Month.first.diagnosis Year.first.diagnosis Uncomplicated.phase Complicated.phase Critical.phase Vasopressors.in.complicated.phase Vasopressors.in.critical.phase Invasive.ventilation.in.critical.phase Superinfection.in.uncomplicated.phase Superinfection.in.complicated.phase Superinfection.in.critical.phase
8998 0.5 0 1 0 1 1 1 0 1 1 0 1 1
9237 0.75 1 0.8181 0 0 1 1 0 1 1 0 0 0
3265 0.75 1 0.2727 0 1 1 0 0 0 0 0 1 0
9349 0.75 1 0.1818 1 1 1 0 0 0 0 0 0 0
5665 0.5 0 1 0 0 1 1 0 0 1 0 0 0
In [43]:
y_train.head()
Out[43]:
Last.known.patient.status
8998 1
9237 1
3265 0
9349 0
5665 1
In [44]:
# One-hot encode the binary target: 0 -> [1, 0] and 1 -> [0, 1].

y_train = tf.keras.utils.to_categorical(y_train.values, num_classes=2)
y_test = tf.keras.utils.to_categorical(y_test.values, num_classes=2)
In [45]:
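# Standardize the features. The scaler is fit on the training set only,
## so no statistics from the test set leak into training.

s = StandardScaler()

x_train = s.fit_transform(x_train)
x_test = s.transform(x_test)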
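
As a rough illustration of what the scaling step does, the sketch below standardizes a single toy column by hand: the mean and (population) standard deviation are estimated on the training values only and then reused for the test values. The helpers `fit_scaler` and `transform` and the toy numbers are assumptions for illustration, not part of the notebook.

```python
from statistics import fmean, pstdev


def fit_scaler(train_column):
    """Estimate mean and population standard deviation from training values."""
    mean = fmean(train_column)
    std = pstdev(train_column)
    # A constant column has zero spread; fall back to 1.0 to avoid dividing by zero.
    return mean, (std if std > 0 else 1.0)


def transform(column, mean, std):
    """Apply z = (x - mean) / std using statistics learned on the training set."""
    return [(value - mean) / std for value in column]


train = [0.0, 0.25, 0.5, 0.75]   # toy training values
test = [1.0]                     # toy test value, never seen during fitting

mean, std = fit_scaler(train)
print(transform(train, mean, std))
print(transform(test, mean, std))
```

The transformed training column has mean 0 and standard deviation 1, while the test value is simply expressed in the training set's units.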
In [46]:
# Reset the Keras state and fix the TensorFlow seed for reproducibility.

K.clear_session()
tf.random.set_seed(42)
In [47]:
model_class = Sequential()
model_class.add(Dense(100, activation='relu', input_shape=(x_train.shape[1],)))
model_class.add(Dropout(0.1))
model_class.add(Dense(50, activation='relu'))
model_class.add(BatchNormalization())
model_class.add(Dropout(0.1))
model_class.add(Dense(10, activation='relu'))
model_class.add(BatchNormalization())
model_class.add(Dropout(0.1))
model_class.add(Dense(5, activation='relu'))
model_class.add(Dense(2, activation='sigmoid'))
In [48]:
model_class.compile(loss = 'binary_crossentropy', optimizer='adamax', metrics=[AUC()])
model_class.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 100)               1400      
                                                                 
 dropout (Dropout)           (None, 100)               0         
                                                                 
 dense_1 (Dense)             (None, 50)                5050      
                                                                 
 batch_normalization (BatchN  (None, 50)               200       
 ormalization)                                                   
                                                                 
 dropout_1 (Dropout)         (None, 50)                0         
                                                                 
 dense_2 (Dense)             (None, 10)                510       
                                                                 
 batch_normalization_1 (Batc  (None, 10)               40        
 hNormalization)                                                 
                                                                 
 dropout_2 (Dropout)         (None, 10)                0         
                                                                 
 dense_3 (Dense)             (None, 5)                 55        
                                                                 
 dense_4 (Dense)             (None, 2)                 12        
                                                                 
=================================================================
Total params: 7,267
Trainable params: 7,147
Non-trainable params: 120
_________________________________________________________________
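
The parameter counts in the summary above can be checked by hand. A dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, and a batch-normalization layer keeps four values per unit: a trainable scale and offset, and a non-trainable moving mean and variance. The sketch below redoes that arithmetic for the 13 input features and the layer sizes of `model_class`; the helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def batchnorm_params(n_units):
    """Return (trainable, non_trainable) parameter counts for batch normalization."""
    # gamma and beta are trained; the moving mean and variance are not.
    return 2 * n_units, 2 * n_units


n_features = 13                      # columns of x_train
dense_sizes = [100, 50, 10, 5, 2]    # Dense layers, in order
batchnorm_sizes = [50, 10]           # BatchNormalization layers, in order

trainable = 0
n_in = n_features
for n_out in dense_sizes:
    trainable += dense_params(n_in, n_out)
    n_in = n_out

non_trainable = 0
for units in batchnorm_sizes:
    bn_trainable, bn_fixed = batchnorm_params(units)
    trainable += bn_trainable
    non_trainable += bn_fixed

print(trainable, non_trainable, trainable + non_trainable)  # prints 7147 120 7267
```

Dropout layers contribute no parameters, which is why they show 0 in the summary.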
In [49]:
binary_class = model_class.fit(x_train, y_train,
                               validation_split=0.1,
                               epochs=50, verbose=2,
                               shuffle=True)
Epoch 1/50
212/212 - 2s - loss: 0.5600 - auc: 0.7987 - val_loss: 0.4688 - val_auc: 0.9319 - 2s/epoch - 10ms/step
Epoch 2/50
212/212 - 0s - loss: 0.3687 - auc: 0.9293 - val_loss: 0.3153 - val_auc: 0.9487 - 474ms/epoch - 2ms/step
Epoch 3/50
212/212 - 0s - loss: 0.3118 - auc: 0.9441 - val_loss: 0.2858 - val_auc: 0.9519 - 464ms/epoch - 2ms/step
Epoch 4/50
212/212 - 1s - loss: 0.2893 - auc: 0.9497 - val_loss: 0.2758 - val_auc: 0.9532 - 644ms/epoch - 3ms/step
Epoch 5/50
212/212 - 0s - loss: 0.2836 - auc: 0.9508 - val_loss: 0.2714 - val_auc: 0.9542 - 480ms/epoch - 2ms/step
Epoch 6/50
212/212 - 1s - loss: 0.2774 - auc: 0.9530 - val_loss: 0.2656 - val_auc: 0.9555 - 590ms/epoch - 3ms/step
Epoch 7/50
212/212 - 1s - loss: 0.2715 - auc: 0.9543 - val_loss: 0.2686 - val_auc: 0.9545 - 685ms/epoch - 3ms/step
Epoch 8/50
212/212 - 0s - loss: 0.2720 - auc: 0.9540 - val_loss: 0.2680 - val_auc: 0.9544 - 451ms/epoch - 2ms/step
Epoch 9/50
212/212 - 0s - loss: 0.2690 - auc: 0.9552 - val_loss: 0.2651 - val_auc: 0.9554 - 500ms/epoch - 2ms/step
Epoch 10/50
212/212 - 0s - loss: 0.2678 - auc: 0.9549 - val_loss: 0.2662 - val_auc: 0.9549 - 425ms/epoch - 2ms/step
Epoch 11/50
212/212 - 0s - loss: 0.2657 - auc: 0.9559 - val_loss: 0.2641 - val_auc: 0.9559 - 453ms/epoch - 2ms/step
Epoch 12/50
212/212 - 0s - loss: 0.2656 - auc: 0.9560 - val_loss: 0.2633 - val_auc: 0.9563 - 434ms/epoch - 2ms/step
Epoch 13/50
212/212 - 0s - loss: 0.2634 - auc: 0.9565 - val_loss: 0.2631 - val_auc: 0.9562 - 445ms/epoch - 2ms/step
Epoch 14/50
212/212 - 0s - loss: 0.2616 - auc: 0.9575 - val_loss: 0.2627 - val_auc: 0.9563 - 434ms/epoch - 2ms/step
Epoch 15/50
212/212 - 0s - loss: 0.2585 - auc: 0.9582 - val_loss: 0.2644 - val_auc: 0.9561 - 416ms/epoch - 2ms/step
Epoch 16/50
212/212 - 0s - loss: 0.2615 - auc: 0.9574 - val_loss: 0.2654 - val_auc: 0.9559 - 437ms/epoch - 2ms/step
Epoch 17/50
212/212 - 0s - loss: 0.2570 - auc: 0.9587 - val_loss: 0.2631 - val_auc: 0.9563 - 388ms/epoch - 2ms/step
Epoch 18/50
212/212 - 0s - loss: 0.2566 - auc: 0.9589 - val_loss: 0.2626 - val_auc: 0.9564 - 422ms/epoch - 2ms/step
Epoch 19/50
212/212 - 0s - loss: 0.2557 - auc: 0.9595 - val_loss: 0.2645 - val_auc: 0.9557 - 441ms/epoch - 2ms/step
Epoch 20/50
212/212 - 0s - loss: 0.2569 - auc: 0.9586 - val_loss: 0.2633 - val_auc: 0.9565 - 418ms/epoch - 2ms/step
Epoch 21/50
212/212 - 0s - loss: 0.2540 - auc: 0.9596 - val_loss: 0.2648 - val_auc: 0.9555 - 407ms/epoch - 2ms/step
Epoch 22/50
212/212 - 0s - loss: 0.2557 - auc: 0.9590 - val_loss: 0.2647 - val_auc: 0.9558 - 400ms/epoch - 2ms/step
Epoch 23/50
212/212 - 0s - loss: 0.2569 - auc: 0.9589 - val_loss: 0.2625 - val_auc: 0.9559 - 419ms/epoch - 2ms/step
Epoch 24/50
212/212 - 0s - loss: 0.2579 - auc: 0.9585 - val_loss: 0.2630 - val_auc: 0.9560 - 422ms/epoch - 2ms/step
Epoch 25/50
212/212 - 0s - loss: 0.2547 - auc: 0.9592 - val_loss: 0.2634 - val_auc: 0.9558 - 415ms/epoch - 2ms/step
Epoch 26/50
212/212 - 0s - loss: 0.2533 - auc: 0.9598 - val_loss: 0.2653 - val_auc: 0.9551 - 426ms/epoch - 2ms/step
Epoch 27/50
212/212 - 0s - loss: 0.2547 - auc: 0.9592 - val_loss: 0.2660 - val_auc: 0.9546 - 417ms/epoch - 2ms/step
Epoch 28/50
212/212 - 0s - loss: 0.2539 - auc: 0.9598 - val_loss: 0.2669 - val_auc: 0.9549 - 398ms/epoch - 2ms/step
Epoch 29/50
212/212 - 0s - loss: 0.2539 - auc: 0.9598 - val_loss: 0.2654 - val_auc: 0.9553 - 424ms/epoch - 2ms/step
Epoch 30/50
212/212 - 0s - loss: 0.2526 - auc: 0.9603 - val_loss: 0.2662 - val_auc: 0.9550 - 433ms/epoch - 2ms/step
Epoch 31/50
212/212 - 0s - loss: 0.2515 - auc: 0.9604 - val_loss: 0.2654 - val_auc: 0.9557 - 417ms/epoch - 2ms/step
Epoch 32/50
212/212 - 0s - loss: 0.2516 - auc: 0.9604 - val_loss: 0.2646 - val_auc: 0.9554 - 403ms/epoch - 2ms/step
Epoch 33/50
212/212 - 0s - loss: 0.2555 - auc: 0.9590 - val_loss: 0.2659 - val_auc: 0.9548 - 429ms/epoch - 2ms/step
Epoch 34/50
212/212 - 0s - loss: 0.2553 - auc: 0.9591 - val_loss: 0.2669 - val_auc: 0.9545 - 414ms/epoch - 2ms/step
Epoch 35/50
212/212 - 0s - loss: 0.2520 - auc: 0.9605 - val_loss: 0.2660 - val_auc: 0.9547 - 413ms/epoch - 2ms/step
Epoch 36/50
212/212 - 0s - loss: 0.2521 - auc: 0.9602 - val_loss: 0.2657 - val_auc: 0.9547 - 414ms/epoch - 2ms/step
Epoch 37/50
212/212 - 0s - loss: 0.2524 - auc: 0.9602 - val_loss: 0.2651 - val_auc: 0.9551 - 463ms/epoch - 2ms/step
Epoch 38/50
212/212 - 0s - loss: 0.2518 - auc: 0.9604 - val_loss: 0.2633 - val_auc: 0.9560 - 444ms/epoch - 2ms/step
Epoch 39/50
212/212 - 0s - loss: 0.2479 - auc: 0.9618 - val_loss: 0.2636 - val_auc: 0.9555 - 413ms/epoch - 2ms/step
Epoch 40/50
212/212 - 0s - loss: 0.2520 - auc: 0.9604 - val_loss: 0.2642 - val_auc: 0.9551 - 438ms/epoch - 2ms/step
Epoch 41/50
212/212 - 0s - loss: 0.2492 - auc: 0.9613 - val_loss: 0.2671 - val_auc: 0.9541 - 430ms/epoch - 2ms/step
Epoch 42/50
212/212 - 0s - loss: 0.2510 - auc: 0.9605 - val_loss: 0.2661 - val_auc: 0.9545 - 414ms/epoch - 2ms/step
Epoch 43/50
212/212 - 0s - loss: 0.2523 - auc: 0.9600 - val_loss: 0.2645 - val_auc: 0.9550 - 419ms/epoch - 2ms/step
Epoch 44/50
212/212 - 0s - loss: 0.2504 - auc: 0.9605 - val_loss: 0.2680 - val_auc: 0.9540 - 421ms/epoch - 2ms/step
Epoch 45/50
212/212 - 0s - loss: 0.2477 - auc: 0.9612 - val_loss: 0.2651 - val_auc: 0.9541 - 452ms/epoch - 2ms/step
Epoch 46/50
212/212 - 0s - loss: 0.2507 - auc: 0.9607 - val_loss: 0.2662 - val_auc: 0.9544 - 440ms/epoch - 2ms/step
Epoch 47/50
212/212 - 0s - loss: 0.2503 - auc: 0.9610 - val_loss: 0.2674 - val_auc: 0.9543 - 409ms/epoch - 2ms/step
Epoch 48/50
212/212 - 0s - loss: 0.2463 - auc: 0.9619 - val_loss: 0.2689 - val_auc: 0.9540 - 440ms/epoch - 2ms/step
Epoch 49/50
212/212 - 0s - loss: 0.2463 - auc: 0.9623 - val_loss: 0.2673 - val_auc: 0.9540 - 429ms/epoch - 2ms/step
Epoch 50/50
212/212 - 0s - loss: 0.2461 - auc: 0.9619 - val_loss: 0.2668 - val_auc: 0.9542 - 411ms/epoch - 2ms/step
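
The "212/212" in the log is the number of mini-batches per epoch. A short sketch of where that number comes from, assuming the Keras default batch size of 32 (no `batch_size` is passed to `fit`) and assuming the usual rounding rules: the test share is rounded up by `train_test_split`, and the validation share is taken from the end of the training data. The helper `steps_per_epoch` is hypothetical.

```python
import math


def steps_per_epoch(n_rows, test_size=0.25, validation_split=0.1, batch_size=32):
    """Estimate the number of training batches per epoch for this pipeline."""
    n_test = math.ceil(n_rows * test_size)        # rows held out for testing
    n_train = n_rows - n_test                     # rows passed to fit()
    # Rows left for gradient updates after the validation split is removed.
    n_fit = math.floor(n_train * (1 - validation_split))
    return math.ceil(n_fit / batch_size)


print(steps_per_epoch(10020))  # prints 212
```

Of the 10,020 rows, 7,515 go to training, 6,763 of those remain after the validation split, and 6,763 / 32 rounds up to 212 batches.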
In [50]:
plt.plot(binary_class.history['auc'])
plt.plot(binary_class.history['val_auc'])
plt.title('AUC vs Epochs')
plt.ylabel('AUC')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='lower right')
plt.show()
In [51]:
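# Evaluate on the held-out test set.
## With two one-hot columns, roc_auc_score averages the AUC over both columns.
metrics.roc_auc_score(y_test, model_class.predict(x_test))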
79/79 [==============================] - 0s 1ms/step
Out[51]:
0.9015763512614858
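
For reference, the ROC AUC reported above can be read as the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case, with ties counting as one half. The sketch below computes it that way on a toy example. The function `roc_auc` and the toy labels and scores are illustrative assumptions, and the pairwise loop is meant for clarity rather than speed.

```python
def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("both classes must be present")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0      # positive ranked above negative
            elif p == n:
                wins += 0.5      # ties count as half
    return wins / (len(positives) * len(negatives))


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))  # prints 0.75
```

The `auc` values in the Keras training log are, as far as I understand, approximated from a fixed set of thresholds over the flattened outputs, so they need not match the scikit-learn figure exactly.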

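These results are promising: the model reaches a ROC AUC of about 0.90 on the held-out test set. However, since the vast majority of patients with COVID-19 recover from mild cases, it is important to check how well this approach works for patients with more severe disease. I will therefore train a nearly identical network (the only change is an extra batch-normalization layer after the first dense layer) on only the patients who entered the complicated phase.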

Complicated data:

In [52]:
complicated_data = complicated_data.drop(columns=['Recovery.phase','Symptoms.in.recovery.phase'])
In [53]:
x_2 = complicated_data
y_2 = x_2.pop('Last.known.patient.status').to_frame()
In [54]:
x_train_2, x_test_2, y_train_2, y_test_2 = train_test_split(x_2, y_2, test_size=0.25, random_state=42)
x_train_2.head()
Out[54]:
Age.at.diagnosis Sex Month.first.diagnosis Year.first.diagnosis Uncomplicated.phase Complicated.phase Critical.phase Vasopressors.in.complicated.phase Vasopressors.in.critical.phase Invasive.ventilation.in.critical.phase Superinfection.in.uncomplicated.phase Superinfection.in.complicated.phase Superinfection.in.critical.phase
5140 0.5 1 0.8181 0 1 1 1 0 0 0 0 0 0
5175 0.5 1 0.909 0 1 1 1 0 1 1 0 0 1
4458 0.5 0 0.0909 0 1 1 0 0 0 0 1 1 0
2252 0.75 1 0.0909 0 1 1 0 0 0 0 0 0 0
7481 1 0 0.909 0 0 1 0 0 0 0 0 0 0
In [55]:
y_train_2 = tf.keras.utils.to_categorical(y_train_2.values, num_classes=2)
y_test_2 = tf.keras.utils.to_categorical(y_test_2.values, num_classes=2)
In [56]:
# Refit the same scaler object on the complicated-phase training data only.

x_train_2 = s.fit_transform(x_train_2)
x_test_2 = s.transform(x_test_2)
In [57]:
K.clear_session()
tf.random.set_seed(42)
In [58]:
model_class_2 = Sequential()
model_class_2.add(Dense(100, activation='relu', input_shape=(x_train_2.shape[1],)))
model_class_2.add(BatchNormalization())
model_class_2.add(Dropout(0.1))
model_class_2.add(Dense(50, activation='relu'))
model_class_2.add(BatchNormalization())
model_class_2.add(Dropout(0.1))
model_class_2.add(Dense(10, activation='relu'))
model_class_2.add(BatchNormalization())
model_class_2.add(Dropout(0.1))
model_class_2.add(Dense(5, activation='relu'))
model_class_2.add(Dense(2, activation='sigmoid'))
In [59]:
model_class_2.compile(loss = 'binary_crossentropy', optimizer='adamax', metrics=[AUC()])
model_class_2.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 100)               1400      
                                                                 
 batch_normalization (BatchN  (None, 100)              400       
 ormalization)                                                   
                                                                 
 dropout (Dropout)           (None, 100)               0         
                                                                 
 dense_1 (Dense)             (None, 50)                5050      
                                                                 
 batch_normalization_1 (Batc  (None, 50)               200       
 hNormalization)                                                 
                                                                 
 dropout_1 (Dropout)         (None, 50)                0         
                                                                 
 dense_2 (Dense)             (None, 10)                510       
                                                                 
 batch_normalization_2 (Batc  (None, 10)               40        
 hNormalization)                                                 
                                                                 
 dropout_2 (Dropout)         (None, 10)                0         
                                                                 
 dense_3 (Dense)             (None, 5)                 55        
                                                                 
 dense_4 (Dense)             (None, 2)                 12        
                                                                 
=================================================================
Total params: 7,667
Trainable params: 7,347
Non-trainable params: 320
_________________________________________________________________
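
The parameter counts in the summary can be checked by hand. The sketch below assumes 13 input features, which is inferred from the first layer's 1,400 parameters rather than stated in this part of the notebook; `dense_params` and `batchnorm_params` are illustrative helper names, not Keras functions.

```python
def dense_params(n_in, n_out):
    # One weight per input-output pair plus one bias per output unit.
    return (n_in + 1) * n_out

def batchnorm_params(n_units):
    # gamma and beta are trainable; the moving mean and variance are not.
    return {"trainable": 2 * n_units, "non_trainable": 2 * n_units}

n_features = 13                 # assumption: inferred from 1400 = (13 + 1) * 100
widths = [100, 50, 10, 5, 2]    # Dense layer widths, in order
bn_widths = [100, 50, 10]       # layers followed by BatchNormalization

dense_total = 0
n_in = n_features
for n_out in widths:
    dense_total += dense_params(n_in, n_out)
    n_in = n_out

bn_trainable = sum(batchnorm_params(w)["trainable"] for w in bn_widths)
bn_frozen = sum(batchnorm_params(w)["non_trainable"] for w in bn_widths)

total = dense_total + bn_trainable + bn_frozen
print(total, total - bn_frozen, bn_frozen)
```

The three printed numbers should line up with the total, trainable, and non-trainable counts reported by `summary()`.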
In [60]:
binary_class = model_class_2.fit(x_train_2, y_train_2,
                               validation_split=0.1,
                               epochs=50, verbose=2,
                               shuffle=True)
Epoch 1/50
100/100 - 2s - loss: 0.6812 - auc: 0.6434 - val_loss: 0.6026 - val_auc: 0.8822 - 2s/epoch - 17ms/step
Epoch 2/50
100/100 - 0s - loss: 0.5478 - auc: 0.8122 - val_loss: 0.5161 - val_auc: 0.9077 - 283ms/epoch - 3ms/step
Epoch 3/50
100/100 - 0s - loss: 0.4878 - auc: 0.8618 - val_loss: 0.4475 - val_auc: 0.9120 - 253ms/epoch - 3ms/step
Epoch 4/50
100/100 - 0s - loss: 0.4533 - auc: 0.8775 - val_loss: 0.4061 - val_auc: 0.9160 - 271ms/epoch - 3ms/step
Epoch 5/50
100/100 - 0s - loss: 0.4364 - auc: 0.8833 - val_loss: 0.3855 - val_auc: 0.9199 - 246ms/epoch - 2ms/step
Epoch 6/50
100/100 - 0s - loss: 0.4259 - auc: 0.8863 - val_loss: 0.3741 - val_auc: 0.9203 - 227ms/epoch - 2ms/step
Epoch 7/50
100/100 - 0s - loss: 0.4137 - auc: 0.8921 - val_loss: 0.3713 - val_auc: 0.9201 - 239ms/epoch - 2ms/step
Epoch 8/50
100/100 - 0s - loss: 0.4120 - auc: 0.8937 - val_loss: 0.3642 - val_auc: 0.9222 - 250ms/epoch - 3ms/step
Epoch 9/50
100/100 - 0s - loss: 0.4067 - auc: 0.8954 - val_loss: 0.3629 - val_auc: 0.9224 - 246ms/epoch - 2ms/step
Epoch 10/50
100/100 - 0s - loss: 0.3985 - auc: 0.8995 - val_loss: 0.3569 - val_auc: 0.9247 - 265ms/epoch - 3ms/step
Epoch 11/50
100/100 - 0s - loss: 0.3991 - auc: 0.8995 - val_loss: 0.3526 - val_auc: 0.9253 - 240ms/epoch - 2ms/step
Epoch 12/50
100/100 - 0s - loss: 0.4019 - auc: 0.8985 - val_loss: 0.3553 - val_auc: 0.9245 - 225ms/epoch - 2ms/step
Epoch 13/50
100/100 - 0s - loss: 0.3970 - auc: 0.8998 - val_loss: 0.3534 - val_auc: 0.9260 - 229ms/epoch - 2ms/step
Epoch 14/50
100/100 - 0s - loss: 0.3955 - auc: 0.9012 - val_loss: 0.3493 - val_auc: 0.9274 - 244ms/epoch - 2ms/step
Epoch 15/50
100/100 - 0s - loss: 0.3916 - auc: 0.9032 - val_loss: 0.3491 - val_auc: 0.9279 - 232ms/epoch - 2ms/step
Epoch 16/50
100/100 - 0s - loss: 0.3852 - auc: 0.9061 - val_loss: 0.3507 - val_auc: 0.9264 - 227ms/epoch - 2ms/step
Epoch 17/50
100/100 - 0s - loss: 0.3918 - auc: 0.9030 - val_loss: 0.3492 - val_auc: 0.9275 - 249ms/epoch - 2ms/step
Epoch 18/50
100/100 - 0s - loss: 0.3870 - auc: 0.9050 - val_loss: 0.3514 - val_auc: 0.9260 - 251ms/epoch - 3ms/step
Epoch 19/50
100/100 - 0s - loss: 0.3870 - auc: 0.9057 - val_loss: 0.3548 - val_auc: 0.9247 - 231ms/epoch - 2ms/step
Epoch 20/50
100/100 - 0s - loss: 0.3846 - auc: 0.9069 - val_loss: 0.3546 - val_auc: 0.9246 - 232ms/epoch - 2ms/step
Epoch 21/50
100/100 - 0s - loss: 0.3840 - auc: 0.9066 - val_loss: 0.3544 - val_auc: 0.9248 - 247ms/epoch - 2ms/step
Epoch 22/50
100/100 - 0s - loss: 0.3914 - auc: 0.9044 - val_loss: 0.3542 - val_auc: 0.9249 - 236ms/epoch - 2ms/step
Epoch 23/50
100/100 - 0s - loss: 0.3873 - auc: 0.9056 - val_loss: 0.3533 - val_auc: 0.9251 - 230ms/epoch - 2ms/step
Epoch 24/50
100/100 - 0s - loss: 0.3899 - auc: 0.9049 - val_loss: 0.3536 - val_auc: 0.9242 - 247ms/epoch - 2ms/step
Epoch 25/50
100/100 - 0s - loss: 0.3797 - auc: 0.9097 - val_loss: 0.3568 - val_auc: 0.9226 - 268ms/epoch - 3ms/step
Epoch 26/50
100/100 - 0s - loss: 0.3792 - auc: 0.9096 - val_loss: 0.3558 - val_auc: 0.9234 - 241ms/epoch - 2ms/step
Epoch 27/50
100/100 - 0s - loss: 0.3848 - auc: 0.9069 - val_loss: 0.3545 - val_auc: 0.9242 - 260ms/epoch - 3ms/step
Epoch 28/50
100/100 - 0s - loss: 0.3831 - auc: 0.9076 - val_loss: 0.3534 - val_auc: 0.9244 - 234ms/epoch - 2ms/step
Epoch 29/50
100/100 - 0s - loss: 0.3841 - auc: 0.9066 - val_loss: 0.3550 - val_auc: 0.9237 - 245ms/epoch - 2ms/step
Epoch 30/50
100/100 - 0s - loss: 0.3815 - auc: 0.9079 - val_loss: 0.3561 - val_auc: 0.9232 - 242ms/epoch - 2ms/step
Epoch 31/50
100/100 - 0s - loss: 0.3797 - auc: 0.9087 - val_loss: 0.3584 - val_auc: 0.9217 - 224ms/epoch - 2ms/step
Epoch 32/50
100/100 - 0s - loss: 0.3811 - auc: 0.9095 - val_loss: 0.3562 - val_auc: 0.9222 - 227ms/epoch - 2ms/step
Epoch 33/50
100/100 - 0s - loss: 0.3777 - auc: 0.9104 - val_loss: 0.3555 - val_auc: 0.9221 - 238ms/epoch - 2ms/step
Epoch 34/50
100/100 - 0s - loss: 0.3726 - auc: 0.9130 - val_loss: 0.3547 - val_auc: 0.9230 - 271ms/epoch - 3ms/step
Epoch 35/50
100/100 - 0s - loss: 0.3716 - auc: 0.9130 - val_loss: 0.3531 - val_auc: 0.9238 - 243ms/epoch - 2ms/step
Epoch 36/50
100/100 - 0s - loss: 0.3764 - auc: 0.9104 - val_loss: 0.3535 - val_auc: 0.9239 - 249ms/epoch - 2ms/step
Epoch 37/50
100/100 - 0s - loss: 0.3794 - auc: 0.9093 - val_loss: 0.3534 - val_auc: 0.9243 - 235ms/epoch - 2ms/step
Epoch 38/50
100/100 - 0s - loss: 0.3717 - auc: 0.9129 - val_loss: 0.3542 - val_auc: 0.9229 - 259ms/epoch - 3ms/step
Epoch 39/50
100/100 - 0s - loss: 0.3719 - auc: 0.9130 - val_loss: 0.3559 - val_auc: 0.9220 - 244ms/epoch - 2ms/step
Epoch 40/50
100/100 - 0s - loss: 0.3731 - auc: 0.9122 - val_loss: 0.3543 - val_auc: 0.9228 - 228ms/epoch - 2ms/step
Epoch 41/50
100/100 - 0s - loss: 0.3738 - auc: 0.9127 - val_loss: 0.3532 - val_auc: 0.9234 - 222ms/epoch - 2ms/step
Epoch 42/50
100/100 - 0s - loss: 0.3721 - auc: 0.9127 - val_loss: 0.3539 - val_auc: 0.9234 - 251ms/epoch - 3ms/step
Epoch 43/50
100/100 - 0s - loss: 0.3783 - auc: 0.9095 - val_loss: 0.3510 - val_auc: 0.9249 - 263ms/epoch - 3ms/step
Epoch 44/50
100/100 - 0s - loss: 0.3768 - auc: 0.9100 - val_loss: 0.3535 - val_auc: 0.9244 - 261ms/epoch - 3ms/step
Epoch 45/50
100/100 - 0s - loss: 0.3776 - auc: 0.9102 - val_loss: 0.3511 - val_auc: 0.9259 - 232ms/epoch - 2ms/step
Epoch 46/50
100/100 - 0s - loss: 0.3669 - auc: 0.9153 - val_loss: 0.3509 - val_auc: 0.9252 - 383ms/epoch - 4ms/step
Epoch 47/50
100/100 - 0s - loss: 0.3776 - auc: 0.9106 - val_loss: 0.3536 - val_auc: 0.9242 - 238ms/epoch - 2ms/step
Epoch 48/50
100/100 - 0s - loss: 0.3709 - auc: 0.9136 - val_loss: 0.3540 - val_auc: 0.9235 - 242ms/epoch - 2ms/step
Epoch 49/50
100/100 - 0s - loss: 0.3640 - auc: 0.9164 - val_loss: 0.3572 - val_auc: 0.9217 - 244ms/epoch - 2ms/step
Epoch 50/50
100/100 - 0s - loss: 0.3815 - auc: 0.9081 - val_loss: 0.3578 - val_auc: 0.9207 - 238ms/epoch - 2ms/step
In [61]:
plt.plot(binary_class.history['auc'])
plt.plot(binary_class.history['val_auc'])
plt.title('AUC vs Epochs')
plt.ylabel('AUC')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='lower right')
plt.show()
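
In the log above, validation AUC peaks early and then drifts slightly downward while training AUC keeps rising. A minimal sketch of locating the best epoch from a list of per-epoch values follows; `best_epoch` is a hypothetical helper, and the excerpt holds only the epoch 13 to 17 values copied from the log.

```python
def best_epoch(values, first_epoch=1):
    """Return (epoch, value) for the largest entry in a list of per-epoch metrics."""
    best_index = max(range(len(values)), key=lambda i: values[i])
    return first_epoch + best_index, values[best_index]

# Validation AUC for epochs 13-17, copied from the training log above.
val_auc_excerpt = [0.9260, 0.9274, 0.9279, 0.9264, 0.9275]

epoch, value = best_epoch(val_auc_excerpt, first_epoch=13)
print(epoch, value)
```

With the real history object, the same helper could be called as `best_epoch(binary_class.history['val_auc'])`. Keras also provides an `EarlyStopping` callback with `restore_best_weights=True`, which may be worth trying instead of training for a fixed 50 epochs.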
In [62]:
metrics.roc_auc_score(y_test_2, model_class_2.predict(x_test_2))
37/37 [==============================] - 0s 1ms/step
Out[62]:
0.8311267130341795
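
For a single output column, AUC can be read as the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. The sketch below computes it that way on invented toy data; `pairwise_auc` is an illustrative function, not the scikit-learn implementation.

```python
def pairwise_auc(y_true, scores):
    """AUC as the fraction of positive-negative pairs ranked correctly (ties count half)."""
    positives = [s for y, s in zip(y_true, scores) if y == 1]
    negatives = [s for y, s in zip(y_true, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))

# Toy labels and scores, invented for illustration only.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(labels, scores))
```

Three of the four positive-negative pairs are ordered correctly here, so the function returns 0.75. Under the same reading, a test AUC of about 0.83 means the model ranks a positive case above a negative one roughly 83% of the time.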

Conclusions & Bonus:¶

The training AUC is slightly lower for the complicated dataset, but performance is still good, with a test AUC of about 0.83. COVID-19 outcomes are hard to predict, so perfect performance is not a realistic expectation. Although the model performs well, it is hard to see how it arrives at its predictions, so I used Shapley values and a surrogate decision tree to interpret it.

Shapley Values¶

In [63]:
x = np.asarray(x).astype('float32')
In [64]:
# Draw 100 distinct rows of x to serve as the SHAP background sample.
background = x[np.random.choice(x.shape[0], 100, replace=False)]
In [65]:
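import shap
import warnings
# Suppress all warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

# Workaround: register a passthrough handler for the AddV2 op, which DeepExplainer may not handle by default.
shap.explainers._deep.deep_tf.op_handlers["AddV2"] = shap.explainers._deep.deep_tf.passthrough

# Explain model_class, using the sampled background rows as the reference.
explainer = shap.DeepExplainer(model_class, background)
shap_values = explainer.shap_values(x_train)
expected_value = explainer.expected_value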
In [66]:
shap.summary_plot(shap_values, x_train, title="Summary Plot") 

The numbers correspond to the variables as follows:

Feature 0: Age

Feature 1: Sex

Feature 2: Month

Feature 3: Year

Feature 4: Uncomplicated phase

Feature 5: Complicated phase

Feature 6: Critical phase

Feature 7: Recovery phase

Feature 8: Vasopressors in complicated phase

Feature 9: Vasopressors in critical phase

Feature 10: Invasive ventilation in critical phase

Feature 11: Superinfection in uncomplicated phase

Feature 12: Superinfection in complicated phase

Feature 13: Superinfection in critical phase

Feature 14: Symptoms in recovery phase

According to the SHAP plot, the three most heavily weighted variables are the patient's age (feature 0), whether the patient entered the critical phase (feature 6), and whether the patient needed vasopressors in the critical phase (feature 9).

Shapley values do not necessarily indicate that these features are the most important for the outcome. They measure how strongly each feature influences the model's predictions, and a feature with a large influence may still be hurting performance. Even so, these findings are consistent with real-world observations. I will also use a surrogate decision tree to evaluate feature importance.
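
To make the idea concrete, here is a minimal sketch that computes exact Shapley values for one prediction of a tiny invented model by enumerating every feature coalition. The model, the patient, and the baseline are all hypothetical. `DeepExplainer` does not enumerate coalitions like this; it uses a DeepLIFT-style approximation suited to neural networks.

```python
from itertools import combinations
from math import factorial

def shapley_values(model, x, baseline):
    """Exact Shapley values for one input; absent features take baseline values."""
    n = len(x)

    def value(subset):
        # Features in `subset` keep their real value; the rest use the baseline.
        mixed = [x[i] if i in subset else baseline[i] for i in range(n)]
        return model(mixed)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                gain = value(set(subset) | {i}) - value(set(subset))
                phi[i] += weight * gain
    return phi

# Hypothetical toy model: a risk score with an age/vasopressor interaction.
def toy_model(features):
    age, critical, vasopressors = features
    return 0.02 * age + 0.3 * critical + 0.01 * age * vasopressors

x = [80, 1, 1]          # invented patient
baseline = [50, 0, 0]   # invented reference patient
phi = shapley_values(toy_model, x, baseline)
print(phi, sum(phi), toy_model(x) - toy_model(baseline))
```

The values sum to the difference between the prediction for `x` and the prediction for the baseline. This is why they describe how the model distributes credit among features rather than whether a feature is useful for the true outcome.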

Surrogate Decision Tree¶

In [67]:
import dalex as dx

explainer = dx.Explainer(model_class, x, y, label='Status')
Preparation of a new explainer is initiated

  -> data              : numpy.ndarray converted to pandas.DataFrame. Columns are set as string numbers.
  -> data              : 10020 rows 13 cols
  -> target variable   : Parameter 'y' was a pandas.DataFrame. Converted to a numpy.ndarray.
  -> target variable   : 10020 values
  -> model_class       : keras.engine.sequential.Sequential (default)
  -> label             : Status
  -> predict function  : <function yhat_tf_classification at 0x000001800CE43B80> will be used (default)
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 0.0191, mean = 0.211, max = 0.705
  -> model type        : classification will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         :  'residual_function' returns an Error when executed:
'float' object has no attribute 'astype'
  -> model_info        : package keras

A new explainer has been created!
In [68]:
explainer.model_parts().plot()

## These seem to correspond with the Shapley values.
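
As far as I understand, `model_parts()` defaults to permutation-based variable importance: each column is shuffled in turn and the resulting drop in performance is recorded, which is why the cell runs the network many times. A minimal sketch of that idea on an invented toy model follows; `permutation_importance` here is a hypothetical helper using squared error, not the dalex implementation.

```python
import random

def permutation_importance(predict, rows, targets, n_repeats=10, seed=0):
    """Mean increase in squared error when one column is shuffled."""
    rng = random.Random(seed)

    def loss(data):
        return sum((predict(r) - t) ** 2 for r, t in zip(data, targets)) / len(data)

    base = loss(rows)
    importances = []
    for col in range(len(rows[0])):
        increases = []
        for _ in range(n_repeats):
            shuffled = [r[col] for r in rows]
            rng.shuffle(shuffled)
            # Rebuild each row with only the chosen column replaced.
            permuted = [r[:col] + [v] + r[col + 1:] for r, v in zip(rows, shuffled)]
            increases.append(loss(permuted) - base)
        importances.append(sum(increases) / n_repeats)
    return importances

# Hypothetical model that uses column 0 and ignores column 1.
def toy_predict(row):
    return 2.0 * row[0]

rows = [[float(i), float(i % 3)] for i in range(8)]
targets = [toy_predict(r) for r in rows]
scores = permutation_importance(toy_predict, rows, targets)
print(scores)
```

Shuffling the ignored column leaves the loss unchanged, so its importance is zero, while shuffling the column the model relies on increases the loss.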
In [69]:
surrogate_model = explainer.model_surrogate(max_vars=3, max_depth=3)
surrogate_model.performance
Out[69]:
                          recall  precision        f1  accuracy       auc
DecisionTreeClassifier  0.767296   0.627787  0.690566  0.967265  0.977021
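
The F1 value in the table can be checked against the reported precision and recall, since F1 is their harmonic mean. A small sketch, with `f1_from` as an illustrative helper name:

```python
def f1_from(precision, recall):
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)

precision, recall = 0.627787, 0.767296  # values from the table above
f1 = f1_from(precision, recall)
print(round(f1, 6))
```

The result agrees with the reported F1 to six decimal places. The gap between the high accuracy and the more modest precision suggests that the positive class is relatively rare, although the table alone does not confirm that.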
In [70]:
surrogate_model.plot()

The surrogate decision tree uses features 9 (whether vasopressors were needed in the critical phase), 0 (age), and 6 (whether the patient entered the critical phase) to approximate the network's prediction of whether a patient will survive COVID-19. According to this model, an older patient who needs vasopressors in the critical phase is the most likely to die of the virus. In all other cases, the model predicts that the patient will likely recover.
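
Read as a rule, the tree described above might look like the sketch below. The function name and the age cutoff are hypothetical: the actual split threshold appears only in the plot, and it may be on a standardized scale, so the cutoff is left as a required argument.

```python
def surrogate_rule(age, entered_critical, vasopressors_critical, age_cutoff):
    """Three-feature rule mirroring the prose description; `age_cutoff` is a placeholder."""
    if entered_critical and vasopressors_critical and age >= age_cutoff:
        return "death"
    return "recovery"

# Invented patients and an invented cutoff, for illustration only.
print(surrogate_rule(82, True, True, age_cutoff=65))
print(surrogate_rule(82, True, False, age_cutoff=65))
print(surrogate_rule(40, True, True, age_cutoff=65))
```

Only the first call falls into the high-risk leaf. The other two are predicted to recover, matching the description that every other path leads to recovery.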

Final Conclusion¶

The analysis above is consistent with real-world observations: the risk of dying from COVID-19 increases substantially with age, and patients who require intensive treatment during the critical phase (e.g. vasopressors) are less likely to survive.

There is currently no reliable way to predict whether a person will die from COVID-19, since the course of the disease is highly variable and many underlying risk factors are at play. However, this model performs well given the limited information it is based on. Adding more data (e.g. socioeconomic factors such as access to healthcare, health conditions such as asthma or heart disease, and the variant of COVID-19 the patient was infected with) would likely improve its performance. Again, 100% accuracy is not a realistic goal, but a generalizable, highly accurate model could be immensely beneficial for public health.